Skip to content

Conversation

@joecummings
Copy link
Member

This PR removes the selective_log_softmax function and consolidates all log probability computation through the existing compute_logprobs function.

Changes:

  • Removed selective_log_softmax from src/forge/util/ops.py
  • Added align parameter to compute_logprobs to handle both usage patterns:
    • align=False: for pre-aligned logits (when model was called with input_ids)
    • align=True: for extracting subset from full-sequence logits
  • Updated all references to use compute_logprobs with appropriate align flag
  • Added comprehensive docstring explaining both usage patterns with examples
  • Removed test_selective_log_softmax.py (functionality now covered by compute_logprobs)

The consolidation improves code maintainability by having a single function for log probability computation, while the align parameter handles the two common usage patterns in RL training.

This PR removes the `selective_log_softmax` function and consolidates all
log probability computation through the existing `compute_logprobs` function.

Changes:
- Removed `selective_log_softmax` from src/forge/util/ops.py
- Added `align` parameter to `compute_logprobs` to handle both usage patterns:
  - align=False: for pre-aligned logits (when model was called with input_ids)
  - align=True: for extracting subset from full-sequence logits
- Updated all references to use `compute_logprobs` with appropriate align flag
- Added comprehensive docstring explaining both usage patterns with examples
- Removed test_selective_log_softmax.py (functionality now covered by compute_logprobs)

The consolidation improves code maintainability by having a single function
for log probability computation, while the align parameter handles the two
common usage patterns in RL training.
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 20, 2025
@JenniferWang
Copy link
Contributor

Could you add unit test for compute_logprobs instead?

@Jack-Khuu
Copy link
Contributor

Could you add unit test for compute_logprobs instead?

https://github.com/meta-pytorch/torchforge/blob/main/tests/unit_tests/util/test_compute_logprobs.py

…r tests

- Renamed test file to better reflect that it tests ops.py functions
- Added three new tests for the align parameter in compute_logprobs:
  - test_align_parameter_false: validates align=False (pre-aligned logits)
  - test_align_parameter_true: validates align=True (slicing behavior)
  - test_align_comparison: verifies slicing logic produces correct results
@joecummings joecummings merged commit ff3290e into meta-pytorch:main Oct 21, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants